from time import time

import taichi as ti
from numpy import arange
from numpy.random import shuffle

from .common import Vec2, fade, mix

# def fade1(x: float): return (3 - x * 2) * x * x


@ti.data_oriented
class Perlin2D:
    """Two-dimensional Perlin noise generator backed by Taichi fields.

    The instance owns three fields:

    * ``gradients``    - ``size`` random unit vectors,
    * ``permutations`` - a shuffled sequence ``0 .. size - 1`` used to hash
      integer lattice coordinates into ``gradients``,
    * ``img``          - a ``size`` x ``size`` float image that ``genImage``
      fills with fractal noise rescaled to the range [0, 1].
    """

    def __init__(self, size: int):
        """Allocate the fields and randomise the lookup tables.

        Args:
            size: Side length of the square output image; also the number of
                entries in the gradient and permutation tables.
        """
        self.size: int = size
        self.gradients = Vec2.field(shape=(size,))
        self.permutations = ti.field(int, (size,))
        self.img = ti.field(float, (self.size, self.size))
        self._init()
        # The permutation table is built on the host with NumPy and then
        # copied into the Taichi field.
        permutation = arange(self.size, dtype=int)
        shuffle(permutation)

        self.permutations.from_numpy(permutation)

    @ti.kernel
    def _init(self):
        """Fill ``gradients`` with randomly oriented, normalised vectors."""
        for i in self.gradients:
            # Each random component is remapped with ``* 2 - 1`` before the
            # vector is normalised to unit length.
            self.gradients[i] = (Vec2(ti.random(), ti.random()) * 2 - 1).normalized()

    @ti.func
    def getIndex(self, x: int, y: int) -> int:
        """Hash lattice point ``(x, y)`` to an index into ``gradients``.

        The row ``y`` is looked up in the permutation table first, offset by
        ``x``, and looked up again; both lookups wrap modulo ``size``.
        """
        return self.permutations[(self.permutations[y % self.size] + x) % self.size]

    @ti.func
    def perlinNoise(self, vecP: Vec2) -> float:
        """Evaluate a single octave of gradient noise at ``vecP``.

        Args:
            vecP: Sample position in lattice units.

        Returns:
            The interpolated dot products of the four surrounding corner
            gradients with the offsets from those corners to ``vecP``.
        """
        # Floor before casting: a bare ``int()`` cast truncates toward zero,
        # which would pick the wrong cell (and yield negative fade weights)
        # for negative coordinates. Non-negative input is unaffected.
        q0 = int(ti.floor(vecP))
        q1 = q0 + 1
        # Offsets from the lower (l0) and upper (l1) cell corners to vecP.
        l0, l1 = vecP - q0, vecP - q1
        q00 = self.gradients[self.getIndex(q0.x, q0.y)]
        q10 = self.gradients[self.getIndex(q1.x, q0.y)]
        q01 = self.gradients[self.getIndex(q0.x, q1.y)]
        q11 = self.gradients[self.getIndex(q1.x, q1.y)]
        v00 = q00.dot(l0)
        v11 = q11.dot(l1)
        # The mixed corners combine one component of each offset, so their
        # dot products are written out by hand.
        v10 = q10.x * l1.x + q10.y * l0.y
        v01 = q01.x * l0.x + q01.y * l1.y
        l3 = fade(l0)
        # Blend along y on the left (x0) and right (x1) cell edges, then
        # blend those two results along x.
        x0 = mix(v00, v01, l3.y)
        x1 = mix(v10, v11, l3.y)
        return mix(x0, x1, l3.x)

    @ti.func
    def FBM(self, vecP: Vec2, step: int = 1) -> float:
        """Fractal Brownian motion: a weighted sum of ``step`` noise octaves.

        Args:
            vecP: Sample position in lattice units for the first octave.
            step: Number of octaves to accumulate.

        Returns:
            The sum of the octaves; each octave doubles the frequency and
            halves the amplitude of the previous one.
        """
        value = 0.0
        # Starting amplitude chosen so the geometric series of ``step``
        # amplitudes (a, a/2, a/4, ...) adds up to exactly 1.
        amplitude = 0.5 / (1 - 0.5**step)
        frequency = 1
        for _ in range(step):
            value += amplitude * self.perlinNoise(vecP * frequency)
            frequency *= 2
            amplitude *= 0.5
        return value

    @ti.kernel
    def genImage(self, scale: float):
        """Fill ``img`` with 21-octave fractal noise rescaled to [0, 1].

        Args:
            scale: Divisor applied to each pixel coordinate before sampling,
                i.e. the number of pixels per lattice unit.
        """
        for i in ti.grouped(self.img):
            self.img[i] = self.FBM(i / scale, 21)
        min_ = 1.0
        max_ = -1.0
        for i, j in self.img:
            # This top-level loop is parallelised by Taichi, so the shared
            # running extrema must be updated atomically; a plain
            # ``min_ = min(min_, ...)`` is a racy read-modify-write.
            ti.atomic_min(min_, self.img[i, j])
            ti.atomic_max(max_, self.img[i, j])
        # From here on ``max_`` holds the value range (max - min).
        max_ -= min_
        for i, j in self.img:
            self.img[i, j] -= min_
            # Skip the division for a constant image to avoid dividing by 0.
            if max_ > 0:
                self.img[i, j] /= max_

    def to_numpy(self):
        """Return the contents of ``img`` as a NumPy array."""
        return self.img.to_numpy()


def wood(x, y, a):
    """Return ``a * 10`` minus its integer part (as given by ``int``).

    ``x`` and ``y`` are accepted but not used.
    """
    scaled = a * 10
    whole = int(scaled)
    return scaled - whole


def marble(x, *_):
    """Return ``sin(x)`` remapped from [-1, 1] to [0, 1].

    Any additional positional arguments are ignored.
    """
    wave = ti.sin(x)
    return wave * 0.5 + 0.5


def text(x, y, a):
    """Return the sine of a combined phase, remapped from [-1, 1] to [0, 1].

    The phase is ``sin(x) + sin(y) + 10 * sin(a)``.
    """
    phase = ti.sin(x) + ti.sin(y) + 10 * ti.sin(a)
    wave = ti.sin(phase)
    return wave * 0.5 + 0.5


if __name__ == "__main__":
    # Demo: build a 1024 x 1024 noise image on the CUDA backend and show it.
    ti.init(arch=ti.cuda)
    resolution = 2**10
    noise = Perlin2D(resolution)
    noise.genImage(16)
    # The generator object itself is handed to imshow, as in the original;
    # presumably imshow obtains the pixels through its ``to_numpy()`` method.
    ti.tools.imshow(noise)
